Hierarchical models

Theory: what is a hierarchical model?

In general: a model with hyperparameters, i.e. parameters that probabilistically control other parameters.

E.g.

\[\begin{align*} y_i &\sim N(\alpha_{group(i)} + \beta \cdot x_i, \sigma) \\ \alpha_{group(i)} &\sim N(\mu, \tau) \end{align*}\]

In this model \(\tau\) is a hyperparameter: it probabilistically controls the group-level intercepts \(\alpha_{group(i)}\). Is \(\mu\) a hyperparameter too?

Hierarchical models are great for describing the situation where you know some measurements have something in common (e.g. they come from the same group), but you don’t know how much.
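The hyperparameter \(\tau\) quantifies exactly this "how much": in the model above, its two extremes recover familiar non-hierarchical models, and estimating \(\tau\) from the data lets the model choose an intermediate amount of pooling.

\[\begin{align*} \tau \to 0 &: \quad \alpha_{group(i)} = \mu \text{ for all groups (complete pooling)} \\ \tau \to \infty &: \quad \text{each } \alpha_{group(i)} \text{ is unconstrained by the others (no pooling)} \end{align*}\]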

Learn more!

Example: always be closing!

Plushycorp employs 10 salespeople who go door to door selling cute plushies. The number of plushies that each salesperson sold every working day for two weeks was recorded. What can Plushycorp find out from this data?

To see what Plushycorp could learn in a best-case scenario, we can use a hierarchical model to run a “digital twin” of this experiment with known parameters and a known data generating process. Specifically, we can assume that the number \(y_{ij}\) of plushies that salesperson \(i\) sells on day \(j\) depends on a combination of factors:

  • The baseline amount \(\mu\) that a totally average salesperson would sell on a normal day
  • The salesperson’s ability \(ability_i\)
  • An effect \(day\ effect_j\) for the day of the week: people are thought to buy fewer and fewer plushies as the week drags on.
  • Some random variation

A good first step for modelling count data is the Poisson distribution, so let’s assume that the sales measurements follow a Poisson distribution:1

1 Note the use of the log link function.

\[\begin{align*} y_{ij} &\sim Poisson(\lambda_{ij}) \\ \ln\lambda_{ij} &= \ln\mu + ability_i + day\ effect_j \end{align*}\]
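To get a feel for the log link, here is a quick numerical sketch (the numbers are made up for illustration). Exponentiating guarantees a positive rate \(\lambda\), and additive effects on the log scale act multiplicatively on the sales scale:

```python
import numpy as np

# Made-up example values: a baseline of 2 plushies per day, a slightly
# good salesperson, a slightly bad day.
mu = 2.0
ability = 0.5
day_effect = -0.2

# The log link: lambda = exp(ln(mu) + ability + day_effect) is always positive.
lam = np.exp(np.log(mu) + ability + day_effect)
print(lam)  # ≈ 2.70: the effects scale the baseline by exp(0.5 - 0.2)
```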

We know that the salespeople have different abilities, but just how different are they? Since this isn’t really clear to Plushycorp, it makes sense to introduce a parameter \(\tau_{ability}\) into the model:

\[\begin{equation*} ability_i \sim N(0, \tau_{ability}) \end{equation*}\]

Now we have a hierarchical model!

We can make a similar argument for the day of the week effects:2

2 Can you think of a better model for day effects, given the information above?

\[\begin{equation*} day\ effect_j \sim N(0, \tau_{day}) \end{equation*}\]

Finally we can complete our model by specifying prior distributions for the non-hierarchical parameters:3

3 \(HN\) here refers to the “half-normal” distribution, a decent default prior for hierarchical standard deviations

\[\begin{align*} \mu &\sim LN(0, 1) \\ \tau_{ability} &\sim HN(0, 1) \\ \tau_{day} &\sim HN(0, 1) \end{align*}\]
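As a quick sanity check on these priors, we can simulate from them directly and look at the implied sales rates. This is a minimal numpy sketch: a half-normal draw is simulated as the absolute value of a normal draw, and \(\mu \sim LN(0, 1)\) is equivalent to \(\ln\mu \sim N(0, 1)\).

```python
import numpy as np

rng = np.random.default_rng(1)
n = 10_000

# mu ~ LN(0, 1) is the same as ln(mu) ~ N(0, 1)
log_mu = rng.normal(0, 1, size=n)
# HN(0, 1) draws: absolute values of standard normal draws
tau_ability = np.abs(rng.normal(0, 1, size=n))
tau_day = np.abs(rng.normal(0, 1, size=n))
# One salesperson ability and one day effect per prior draw
ability = rng.normal(0, tau_ability)
day_effect = rng.normal(0, tau_day)

# Implied sales rate for one salesperson-day under the prior
lam = np.exp(log_mu + ability + day_effect)
print(np.round(np.quantile(lam, [0.05, 0.5, 0.95]), 2))
```

If the implied rates looked absurd for door-to-door plushie sales (say, thousands per day in the bulk of the distribution), that would be a sign the priors need rethinking.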

To test out our model, we can use Python to generate a fake set of salespeople and days, then simulate some sales consistent with the model. First, some setup:

from pathlib import Path
import json
import numpy as np
import pandas as pd

N_SALESPERSON = 10
N_WEEK = 2
DAY_NAMES = ["Mon", "Tue", "Wed", "Thu", "Fri"]
BASELINE = 2  # 2 plushies in one day is fine
TAU_ABILITY = 0.35
TAU_DAY = 0.2

SEED = 12345
DATA_DIR = Path("../data")

rng = np.random.default_rng(seed=SEED)

with open(DATA_DIR / "names.json", "r") as f:
    name_directory = json.load(f)

names = [
    f"{first_name} {surname}"
    for first_name, surname in zip(
        *map(
            lambda l: rng.choice(l, size=N_SALESPERSON, replace=False),
            name_directory.values()
        )
    )
]

abilities = rng.normal(loc=0, scale=TAU_ABILITY, size=N_SALESPERSON)

salespeople = pd.DataFrame({"salesperson": names, "ability": abilities})

salespeople
salesperson ability
0 Morten Andersen 0.489643
1 Lene Poulsen 0.462804
2 Rasmus Jensen -0.104894
3 Hanne Madsen 0.316022
4 Mette Rasmussen -0.567554
5 Christian Christensen -0.055366
6 Helle Kristensen 0.157319
7 Charlotte Hansen -0.470260
8 Maria Petersen -0.028591
9 Jette Thomsen 0.603659
day_effects = sorted(
    rng.normal(loc=0, scale=TAU_DAY, size=len(DAY_NAMES))
)[::-1]  # This (i.e. `[::-1]`) is a nice way to reverse a list
days = pd.DataFrame({"day": DAY_NAMES, "day_effect": day_effects})
days
day day_effect
0 Mon 0.523632
1 Tue 0.165727
2 Wed 0.155472
3 Thu -0.191798
4 Fri -0.241878
sales = (
    days
    .merge(salespeople, how="cross")
    .merge(pd.DataFrame({"week":[1, 2, 3, 4]}), how="cross")
    .assign(
        sales=lambda df: rng.poisson(
            np.exp(np.log(BASELINE) + df["ability"] + df["day_effect"])
        )
    )
    [["week", "day", "salesperson", "day_effect", "ability", "sales"]]
    .copy()
)
sales.head()
week day salesperson day_effect ability sales
0 1 Mon Morten Andersen 0.523632 0.489643 10
1 2 Mon Morten Andersen 0.523632 0.489643 3
2 3 Mon Morten Andersen 0.523632 0.489643 4
3 4 Mon Morten Andersen 0.523632 0.489643 4
4 1 Mon Lene Poulsen 0.523632 0.462804 4

Here is the fortnightly sales chart

total_sales = (
    sales.groupby("salesperson")["sales"].sum().sort_values(ascending=False)
)

total_sales.plot(kind="bar", ylabel="Plushies sold", title="Fortnightly sales")

It’s pretty straightforward to represent hierarchical models with Stan, almost like Stan was designed for it!

from cmdstanpy import CmdStanModel

model = CmdStanModel(stan_file="../src/stan/plushies.stan")
print(model.code())
data {
 int<lower=1> N;
 int<lower=1> N_salesperson;
 int<lower=1> N_day;
 array[N] int<lower=1,upper=N_salesperson> salesperson;
 array[N] int<lower=1,upper=N_day> day;
 array[N] int<lower=0> sales;
 int<lower=0,upper=1> likelihood;
}
parameters {
 real log_mu;
 vector[N_salesperson] ability;
 vector[N_day] day_effect;
 real<lower=0> tau_ability;
 real<lower=0> tau_day;
}
transformed parameters {
 vector[N] log_lambda = log_mu + ability[salesperson] + day_effect[day]; 
}
model {
  log_mu ~ normal(0, 1);
  ability ~ normal(0, tau_ability);
  day_effect ~ normal(0, tau_day);
  tau_ability ~ normal(0, 0.5);
  tau_day ~ normal(0, 0.5);
  if (likelihood){
    sales ~ poisson_log(log_lambda);
  }
}
generated quantities {
 real mu = exp(log_mu);
 vector[N] lambda = exp(log_lambda);
 array[N] int yrep = poisson_rng(lambda);
 vector[N] llik; 
 for (n in 1:N){
   llik[n] = poisson_lpmf(sales[n] | lambda[n]);
 }
}

import arviz as az
from stanio.json import process_dictionary

def one_encode(l):
    """One-encode a 1d list-like thing."""
    return dict(zip(l, range(1, len(l) + 1)))


salesperson_codes = one_encode(salespeople["salesperson"])
day_codes = one_encode(days["day"])
data_prior = process_dictionary({
        "N": len(sales),
        "N_salesperson": len(salespeople),
        "N_day": len(days),
        "salesperson": sales["salesperson"].map(salesperson_codes),
        "day": sales["day"].map(day_codes),
        "sales": sales["sales"],
        "likelihood": 0
    }
)
data_posterior = data_prior | {"likelihood": 1}
mcmc_prior = model.sample(data=data_prior)
mcmc_posterior = model.sample(data=data_posterior)
idata = az.from_cmdstanpy(
    posterior=mcmc_posterior,
    prior=mcmc_prior,
    log_likelihood="llik",
    posterior_predictive="yrep",
    observed_data=data_posterior,
    coords={
        "salesperson": salespeople["salesperson"],
        "day": days["day"],
        "observation": sales.index
    },
    dims={
        "lambda": ["observation"],
        "ability": ["salesperson"],
        "day_effect": ["day"],
        "llik": ["observation"],
        "yrep": ["observation"]
    }
)
idata
17:02:44 - cmdstanpy - INFO - CmdStan start processing
17:02:45 - cmdstanpy - INFO - CmdStan done processing.
17:02:45 - cmdstanpy - WARNING - Non-fatal error during sampling:
Exception: normal_lpdf: Scale parameter is 0, but must be positive! (in 'plushies.stan', line 22, column 2 to column 35)
Consider re-running with show_console=True if the above output is unclear!
17:02:45 - cmdstanpy - WARNING - Some chains may have failed to converge.
    Chain 1 had 5 divergent transitions (0.5%)
    Chain 2 had 9 divergent transitions (0.9%)
    Chain 3 had 10 divergent transitions (1.0%)
    Chain 4 had 2 divergent transitions (0.2%)
    Use the "diagnose()" method on the CmdStanMCMC object to see further information.
17:02:45 - cmdstanpy - INFO - CmdStan start processing
17:02:46 - cmdstanpy - INFO - CmdStan done processing.
17:02:46 - cmdstanpy - WARNING - Non-fatal error during sampling:
Exception: normal_lpdf: Scale parameter is 0, but must be positive! (in 'plushies.stan', line 22, column 2 to column 35)
Consider re-running with show_console=True if the above output is unclear!
arviz.InferenceData
    • <xarray.Dataset> Size: 13MB
      Dimensions:           (chain: 4, draw: 1000, salesperson: 10, day: 5,
                             log_lambda_dim_0: 200, observation: 200)
      Coordinates:
        * chain             (chain) int64 32B 0 1 2 3
        * draw              (draw) int64 8kB 0 1 2 3 4 5 6 ... 994 995 996 997 998 999
        * salesperson       (salesperson) object 80B 'Morten Andersen' ... 'Jette T...
        * day               (day) object 40B 'Mon' 'Tue' 'Wed' 'Thu' 'Fri'
        * log_lambda_dim_0  (log_lambda_dim_0) int64 2kB 0 1 2 3 4 ... 196 197 198 199
        * observation       (observation) int64 2kB 0 1 2 3 4 ... 195 196 197 198 199
      Data variables:
          log_mu            (chain, draw) float64 32kB 0.9334 0.8925 ... 0.7062 0.7893
          ability           (chain, draw, salesperson) float64 320kB 0.283 ... 0.5017
          day_effect        (chain, draw, day) float64 160kB 0.3475 ... -0.03229
          tau_ability       (chain, draw) float64 32kB 0.3533 0.3118 ... 0.2699 0.3125
          tau_day           (chain, draw) float64 32kB 0.2344 0.1823 ... 0.3317 0.311
          log_lambda        (chain, draw, log_lambda_dim_0) float64 6MB 1.564 ... 1...
          mu                (chain, draw) float64 32kB 2.543 2.441 ... 2.026 2.202
          lambda            (chain, draw, observation) float64 6MB 4.777 ... 3.521
      Attributes:
          created_at:                 2024-04-29T15:02:46.471020
          arviz_version:              0.17.0
          inference_library:          cmdstanpy
          inference_library_version:  1.2.1

    • <xarray.Dataset> Size: 6MB
      Dimensions:      (chain: 4, draw: 1000, observation: 200)
      Coordinates:
        * chain        (chain) int64 32B 0 1 2 3
        * draw         (draw) int64 8kB 0 1 2 3 4 5 6 ... 993 994 995 996 997 998 999
        * observation  (observation) int64 2kB 0 1 2 3 4 5 ... 194 195 196 197 198 199
      Data variables:
          yrep         (chain, draw, observation) float64 6MB 4.0 4.0 5.0 ... 4.0 1.0
      Attributes:
          created_at:                 2024-04-29T15:02:46.477255
          arviz_version:              0.17.0
          inference_library:          cmdstanpy
          inference_library_version:  1.2.1

    • <xarray.Dataset> Size: 6MB
      Dimensions:      (chain: 4, draw: 1000, observation: 200)
      Coordinates:
        * chain        (chain) int64 32B 0 1 2 3
        * draw         (draw) int64 8kB 0 1 2 3 4 5 6 ... 993 994 995 996 997 998 999
        * observation  (observation) int64 2kB 0 1 2 3 4 5 ... 194 195 196 197 198 199
      Data variables:
          llik         (chain, draw, observation) float64 6MB -4.243 -1.877 ... -1.697
      Attributes:
          created_at:                 2024-04-29T15:02:46.814273
          arviz_version:              0.17.0
          inference_library:          cmdstanpy
          inference_library_version:  1.2.1

    • <xarray.Dataset> Size: 204kB
      Dimensions:          (chain: 4, draw: 1000)
      Coordinates:
        * chain            (chain) int64 32B 0 1 2 3
        * draw             (draw) int64 8kB 0 1 2 3 4 5 6 ... 994 995 996 997 998 999
      Data variables:
          lp               (chain, draw) float64 32kB 3.48 4.691 ... -1.735 6.093
          acceptance_rate  (chain, draw) float64 32kB 0.9954 0.9843 ... 0.8181 0.9996
          step_size        (chain, draw) float64 32kB 0.1786 0.1786 ... 0.1646 0.1646
          tree_depth       (chain, draw) int64 32kB 4 4 4 4 4 4 4 4 ... 4 4 4 4 4 4 4
          n_steps          (chain, draw) int64 32kB 15 15 31 15 31 ... 15 15 15 15 15
          diverging        (chain, draw) bool 4kB False False False ... False False
          energy           (chain, draw) float64 32kB 4.368 1.675 ... 13.69 6.133
      Attributes:
          created_at:                 2024-04-29T15:02:46.475030
          arviz_version:              0.17.0
          inference_library:          cmdstanpy
          inference_library_version:  1.2.1

    • <xarray.Dataset> Size: 26MB
      Dimensions:           (chain: 4, draw: 1000, salesperson: 10, day: 5,
                             log_lambda_dim_0: 200, observation: 200)
      Coordinates:
        * chain             (chain) int64 32B 0 1 2 3
        * draw              (draw) int64 8kB 0 1 2 3 4 5 6 ... 994 995 996 997 998 999
        * salesperson       (salesperson) object 80B 'Morten Andersen' ... 'Jette T...
        * day               (day) object 40B 'Mon' 'Tue' 'Wed' 'Thu' 'Fri'
        * log_lambda_dim_0  (log_lambda_dim_0) int64 2kB 0 1 2 3 4 ... 196 197 198 199
        * observation       (observation) int64 2kB 0 1 2 3 4 ... 195 196 197 198 199
      Data variables:
          log_mu            (chain, draw) float64 32kB -0.8982 0.7227 ... -0.09284
          ability           (chain, draw, salesperson) float64 320kB -0.07659 ... -...
          day_effect        (chain, draw, day) float64 160kB -0.6431 ... 0.1662
          tau_ability       (chain, draw) float64 32kB 0.5487 0.2756 ... 0.5001 0.5082
          tau_day           (chain, draw) float64 32kB 1.425 0.3605 ... 0.1924 0.2192
          log_lambda        (chain, draw, log_lambda_dim_0) float64 6MB -1.618 ... ...
          mu                (chain, draw) float64 32kB 0.4073 2.06 ... 0.9009 0.9113
          lambda            (chain, draw, observation) float64 6MB 0.1983 ... 0.8162
          yrep              (chain, draw, observation) float64 6MB 0.0 0.0 ... 2.0 1.0
          llik              (chain, draw, observation) float64 6MB -31.48 ... -1.916
      Attributes:
          created_at:                 2024-04-29T15:02:46.806779
          arviz_version:              0.17.0
          inference_library:          cmdstanpy
          inference_library_version:  1.2.1

    • <xarray.Dataset> Size: 204kB
      Dimensions:          (chain: 4, draw: 1000)
      Coordinates:
        * chain            (chain) int64 32B 0 1 2 3
        * draw             (draw) int64 8kB 0 1 2 3 4 5 6 ... 994 995 996 997 998 999
      Data variables:
          lp               (chain, draw) float64 32kB -4.523 3.679 ... 4.492 5.387
          acceptance_rate  (chain, draw) float64 32kB 0.9768 0.987 ... 0.9953 0.9733
          step_size        (chain, draw) float64 32kB 0.1516 0.1516 ... 0.1653 0.1653
          tree_depth       (chain, draw) int64 32kB 4 5 4 4 4 4 4 4 ... 5 5 5 5 5 5 3
          n_steps          (chain, draw) int64 32kB 15 47 15 15 15 ... 31 31 31 31 15
          diverging        (chain, draw) bool 4kB False False False ... False False
          energy           (chain, draw) float64 32kB 11.26 11.58 ... 7.702 -0.3423
      Attributes:
          created_at:                 2024-04-29T15:02:46.810452
          arviz_version:              0.17.0
          inference_library:          cmdstanpy
          inference_library_version:  1.2.1

    • <xarray.Dataset> Size: 10kB
      Dimensions:              (N_dim_0: 1, N_salesperson_dim_0: 1, N_day_dim_0: 1,
                                salesperson_dim_0: 200, day_dim_0: 200,
                                sales_dim_0: 200, likelihood_dim_0: 1)
      Coordinates:
        * N_dim_0              (N_dim_0) int64 8B 0
        * N_salesperson_dim_0  (N_salesperson_dim_0) int64 8B 0
        * N_day_dim_0          (N_day_dim_0) int64 8B 0
        * salesperson_dim_0    (salesperson_dim_0) int64 2kB 0 1 2 3 ... 197 198 199
        * day_dim_0            (day_dim_0) int64 2kB 0 1 2 3 4 ... 195 196 197 198 199
        * sales_dim_0          (sales_dim_0) int64 2kB 0 1 2 3 4 ... 196 197 198 199
        * likelihood_dim_0     (likelihood_dim_0) int64 8B 0
      Data variables:
          N                    (N_dim_0) int64 8B 200
          N_salesperson        (N_salesperson_dim_0) int64 8B 10
          N_day                (N_day_dim_0) int64 8B 5
          salesperson          (salesperson_dim_0) int64 2kB 1 1 1 1 2 ... 10 10 10 10
          day                  (day_dim_0) int64 2kB 1 1 1 1 1 1 1 1 ... 5 5 5 5 5 5 5
          sales                (sales_dim_0) int64 2kB 10 3 4 4 4 5 6 ... 4 1 1 3 3 2
          likelihood           (likelihood_dim_0) int64 8B 1
      Attributes:
          created_at:                 2024-04-29T15:02:46.812614
          arviz_version:              0.17.0
          inference_library:          cmdstanpy
          inference_library_version:  1.2.1

az.summary(idata, var_names="~lambda", filter_vars="regex")
mean sd hdi_3% hdi_97% mcse_mean mcse_sd ess_bulk ess_tail r_hat
log_mu 0.782 0.219 0.389 1.222 0.009 0.007 609.0 773.0 1.01
ability[Morten Andersen] 0.435 0.171 0.121 0.767 0.005 0.004 1126.0 1149.0 1.00
ability[Lene Poulsen] 0.370 0.173 0.062 0.722 0.005 0.004 1193.0 1643.0 1.00
ability[Rasmus Jensen] -0.168 0.183 -0.520 0.157 0.005 0.004 1221.0 1886.0 1.00
ability[Hanne Madsen] 0.100 0.175 -0.222 0.437 0.005 0.004 1321.0 1432.0 1.00
ability[Mette Rasmussen] -0.543 0.213 -0.967 -0.185 0.005 0.004 1552.0 2311.0 1.00
ability[Christian Christensen] -0.126 0.185 -0.460 0.232 0.005 0.004 1384.0 1931.0 1.00
ability[Helle Kristensen] 0.181 0.176 -0.140 0.519 0.005 0.004 1296.0 1544.0 1.00
ability[Charlotte Hansen] -0.282 0.189 -0.653 0.052 0.005 0.004 1369.0 2032.0 1.00
ability[Maria Petersen] -0.214 0.185 -0.580 0.118 0.005 0.003 1601.0 2273.0 1.00
ability[Jette Thomsen] 0.408 0.172 0.066 0.720 0.005 0.004 1046.0 1578.0 1.00
day_effect[Mon] 0.342 0.193 0.019 0.759 0.008 0.005 640.0 838.0 1.00
day_effect[Tue] 0.122 0.191 -0.220 0.524 0.008 0.005 643.0 920.0 1.01
day_effect[Wed] 0.130 0.194 -0.236 0.502 0.008 0.005 654.0 802.0 1.00
day_effect[Thu] -0.305 0.195 -0.655 0.097 0.007 0.006 706.0 845.0 1.01
day_effect[Fri] -0.177 0.193 -0.564 0.186 0.008 0.006 668.0 875.0 1.00
tau_ability 0.393 0.113 0.210 0.600 0.002 0.002 2091.0 2631.0 1.00
tau_day 0.350 0.152 0.123 0.635 0.004 0.003 1512.0 1926.0 1.00
mu 2.240 0.498 1.361 3.204 0.021 0.015 609.0 773.0 1.01

The problem with hierarchical models: funnels

Did you notice that cmdstanpy printed some divergent transition warnings above? This illustrates a pervasive problem with hierarchical models: funnel-shaped marginal posterior distributions. The plot below shows the values of the parameter \(\tau_{day}\) and the corresponding day effect values for Monday in the prior samples:

az.plot_pair(
    idata.prior,
    var_names=["tau_day", "day_effect"],
    coords={"day": ["Mon"]},
);
/Users/tedgro/repos/biosustain/bayesian_statistics_for_computational_biology/.venv/lib/python3.12/site-packages/arviz/plots/pairplot.py:232: FutureWarning: The return type of `Dataset.dims` will be changed to return a set of dimension names in future, in order to be more consistent with `DataArray.dims`. To access a mapping from dimension names to lengths, please use `Dataset.sizes`.
  gridsize = int(dataset.dims["draw"] ** 0.35)

As we discussed previously, funnels are hard to sample because of their inconsistent characteristic lengths. Unfortunately, they are often inevitable in hierarchical models. Can you see why from the graph?
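The inconsistent length scales can also be seen numerically. In the prior, \(day\ effect \mid \tau_{day} \sim N(0, \tau_{day})\), so the conditional standard deviation is \(\tau_{day}\) itself; a small sketch with numpy:

```python
import numpy as np

rng = np.random.default_rng(0)

# The conditional scale of day_effect IS tau_day, so the step size the
# sampler needs changes by orders of magnitude between the funnel's
# narrow neck (small tau_day) and its wide bowl (large tau_day).
for tau_day in [0.01, 0.1, 1.0]:
    day_effect = rng.normal(0, tau_day, size=10_000)
    print(f"tau_day={tau_day:>5}: sd(day_effect) ≈ {day_effect.std():.3f}")
```

No single step size suits both regions at once, which is what produces divergent transitions.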

There are three main solutions to funnels: add more information, tune the HMC algorithm or reparameterise the model.

Add more information

Sampling the posterior distribution didn’t produce any divergent transitions, probably because the extra information in the measurements made the geometry easier to sample. Comparing the marginal distributions from above illustrates how this can happen: note that the difference in scale between the neck and the bowl of the funnel is less extreme for the posterior samples.

from matplotlib import pyplot as plt
f, ax = plt.subplots()
az.plot_pair(
    idata.prior,
    var_names=["tau_day", "day_effect"],
    coords={"day": ["Mon"]},
    ax=ax,
    scatter_kwargs={"label": "prior"},
);
az.plot_pair(
    idata.posterior,
    var_names=["tau_day", "day_effect"],
    coords={"day": ["Mon"]},
    ax=ax,
    scatter_kwargs={"label": "posterior"},
);
ax.legend(frameon=False);

If better measurements aren’t available, divergences can often be avoided by searching for extra information that can justify narrower priors.

Tune the algorithm

Stan allows increasing the length of the warmup phase (iter_warmup, default 1000), bringing the target acceptance probability closer to 1 (adapt_delta, default 0.8), and increasing the leapfrog integrator’s maximum tree depth (max_treedepth, default 10). All of these changes trade speed for reliability.

mcmc_prior_2 = model.sample(
    data=data_prior,
    iter_warmup=3000,
    adapt_delta=0.99,
    max_treedepth=12
)
17:02:47 - cmdstanpy - INFO - CmdStan start processing
17:02:53 - cmdstanpy - INFO - CmdStan done processing.
17:02:53 - cmdstanpy - WARNING - Non-fatal error during sampling:
Exception: normal_lpdf: Scale parameter is 0, but must be positive! (in 'plushies.stan', line 23, column 2 to column 34)
Consider re-running with show_console=True if the above output is unclear!

Unfortunately even quite aggressive tuning doesn’t get rid of all the divergent transitions in this case.

Reparameterise

The idea with reparameterisation is to define auxiliary parameters that don’t have problematic relationships, then recover the original parameters from them afterwards.

“Non-centred” parameterisations take a distribution with the form \(\alpha\sim D(\mu,\sigma)\) and express it as follows:

\[\begin{align*} u &\sim D(0, 1) \\ \alpha &= \mu + u \cdot \sigma \end{align*}\]
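Before looking at the Stan version, here is a quick numerical check (numpy, with arbitrary illustrative values \(\mu = 1.5\), \(\sigma = 0.3\)) that the non-centred construction really does yield draws from \(N(\mu, \sigma)\):

```python
import numpy as np

rng = np.random.default_rng(42)
mu, sigma = 1.5, 0.3  # arbitrary illustrative values
n = 100_000

# Centred: draw alpha directly from N(mu, sigma)
alpha_centred = rng.normal(mu, sigma, size=n)
# Non-centred: draw a standard normal, then shift and scale
u = rng.normal(0, 1, size=n)
alpha_noncentred = mu + u * sigma

for draws in (alpha_centred, alpha_noncentred):
    print(f"mean ≈ {draws.mean():.3f}, sd ≈ {draws.std():.3f}")
```

The two parameterisations describe the same distribution; what changes is the geometry the sampler sees, because in the non-centred version \(u\) is a priori independent of the hierarchical standard deviation.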

model_nc = CmdStanModel(stan_file="../src/stan/plushies-nc.stan")
print(model_nc.code())
data {
 int<lower=1> N;
 int<lower=1> N_salesperson;
 int<lower=1> N_day;
 array[N] int<lower=1,upper=N_salesperson> salesperson;
 array[N] int<lower=1,upper=N_day> day;
 array[N] int<lower=0> sales;
 int<lower=0,upper=1> likelihood;
}
parameters {
 real log_mu;
 vector[N_salesperson] ability_z;
 vector[N_day] day_effect_z;
 real<lower=0> tau_ability;
 real<lower=0> tau_day;
}
transformed parameters {
 vector[N_salesperson] ability = ability_z * tau_ability;
 vector[N_day] day_effect = day_effect_z * tau_day;
 vector[N] log_lambda = log_mu + ability[salesperson] + day_effect[day]; 
}
model {
  log_mu ~ normal(0, 1);
  ability_z ~ normal(0, 1);
  day_effect_z ~ normal(0, 1);
  tau_ability ~ normal(0, 1);
  tau_day ~ normal(0, 1);
  if (likelihood){
    sales ~ poisson_log(log_lambda);
  }
}
generated quantities {
 real mu = exp(log_mu);
 vector[N] lambda = exp(log_lambda);
 array[N] int yrep = poisson_rng(lambda);
 vector[N] llik; 
 for (n in 1:N){
   llik[n] = poisson_lpmf(sales[n] | lambda[n]);
 }
}
mcmc_prior_nc = model_nc.sample(
    data=data_prior,
    iter_warmup=3000,
    adapt_delta=0.999,
    max_treedepth=12
)
17:02:53 - cmdstanpy - INFO - CmdStan start processing
17:02:59 - cmdstanpy - INFO - CmdStan done processing.
17:02:59 - cmdstanpy - WARNING - Non-fatal error during sampling:
Exception: normal_lpdf: Scale parameter is 0, but must be positive! (in 'plushies.stan', line 23, column 2 to column 34)
Consider re-running with show_console=True if the above output is unclear!

Beware of using non-centred parameterisation as a default: it isn’t guaranteed to be better.

So how many plushies do I need to sell?

f, ax = plt.subplots()
az.plot_forest(
    np.exp(idata.posterior["log_mu"] + idata.posterior["ability"]),
    kind="forestplot",
    combined=True,
    ax=ax,
    show=False,
);
ax.scatter(
    np.exp(np.log(BASELINE) + salespeople["ability"]), 
    ax.get_yticks()[::-1], 
    color="red", 
    label="True expected sales",
    zorder=2
)
ax.scatter(
    sales.groupby("salesperson")["sales"].mean().reindex(salespeople["salesperson"]), 
    ax.get_yticks()[::-1], 
    color="black", 
    label="Observed sales per day",
    zorder=3
)
ax.set(title="", xlabel="Number of plushies sold per day")
ax.axvline(BASELINE, linestyle="--", label="baseline", linewidth=0.8, color="black")
ax.legend(frameon=False);

Takeaways

  • Hierarchical models are a powerful way to capture structural information
  • You may run into problematic sampling, but you have options!
  • There is surprisingly little information in low-expected-value count data.